function net = cnn_init_classification(imdb, lambda,dataset)  %_3L
%CNN_INIT_CLASSIFICATION  Build an untrained, randomly initialised network.
%   NET = CNN_INIT_CLASSIFICATION(IMDB, LAMBDA, DATASET) returns a struct
%   whose NET.layers field is a cell array of layer structs ('conv',
%   'relu', 'pool', 'dropout' and a final loss layer), and prints a summary
%   of the network with vl_simplenn_display.
%
%   IMDB     struct; only IMDB.images.data (its first two dimensions give
%            the displayed input size, and its first dimension the feature
%            count for tabular datasets) and IMDB.images.labels (its
%            maximum gives the number of outputs for tabular datasets)
%            are read.
%   LAMBDA   numeric code selecting the loss layer appended at the end:
%               0  softmaxloss            1  softmaxloss_relevance
%               2  mshinge                3  mshinge_relevance
%               4  euclidean              5  euclidean_relevance_prod
%               6  euclidean_relevance_prodsup1
%               7  euclidean_relevance_sum
%               8  mshingesquared         9  mshingesquared_relevance
%              10  sigmoid + euclidean
%              11  sigmoid + euclidean_relevance_prodsup1
%              12  mshingecubed          13  mshingecubed_relevance
%            Any other value leaves the network without a loss layer and
%            raises a warning.
%   DATASET  character vector. Names containing 'mnist' or 'cifar' select
%            a convolutional architecture; otherwise DATASET must be one of
%            'ionosphere', 'WP_Breast_Cancer', 'SPECTF_Heart', 'pid',
%            'glass', 'yeast_8l', 'car', 'satimage' or 'thyroid', which
%            select a three-layer fully connected architecture.
%
%   An error is raised when DATASET is not recognised.

% Fix the random generator so that the initial weights are reproducible.
rng('default');
rng(0);

% Scale factor applied to the Gaussian random initial filters.
f = 1/100;
% Constant value stored in the 'biases' field of every conv layer.
bias = 0.01;

% Constructor for a stride-1 convolution layer. The field order and values
% are the same for every conv layer of every architecture below; only the
% filter shape (h x w x nIn x nOut), the padding and the name vary.
% Note that randn is called exactly once per layer, in layer order, so the
% generated weights depend on the order in which layers are created.
convLayer = @(h, w, nIn, nOut, pad, name) struct( ...
    'type', 'conv', ...
    'weights', {{f*randn(h, w, nIn, nOut, 'single'), zeros(1, nOut, 'single')}}, ...
    'biases', bias*ones(1, nOut, 'single'), ...
    'stride', 1, ...
    'pad', pad, ...
    'filtersLearningRate', 1, ...
    'biasesLearningRate', 2, ...
    'filtersWeightDecay', 1, ...
    'biasesWeightDecay', 0, ...
    'name', name) ;

% Parameter-free layers shared by the architectures below.
reluLayer    = struct('type', 'relu') ;
poolLayer    = struct('type', 'pool', ...
                      'method', 'max', ...
                      'pool', [2 2], ...
                      'stride', 2, ...
                      'pad', 0) ;          % 2x2 max pooling, stride 2
dropoutLayer = struct('type', 'dropout', 'rate', 0.5) ;

net.layers = {} ;

if contains(dataset, 'mnist')
    %% MNIST: conv-pool, conv-pool, then two layers acting as fully connected
    net.layers{end+1} = convLayer(5, 5, 1, 20, 0, 'conv1') ;
    net.layers{end+1} = reluLayer ;
    net.layers{end+1} = poolLayer ;

    net.layers{end+1} = convLayer(5, 5, 20, 50, 0, 'conv2') ;
    net.layers{end+1} = reluLayer ;
    net.layers{end+1} = poolLayer ;

    net.layers{end+1} = convLayer(4, 4, 50, 500, 0, 'fc2') ;
    net.layers{end+1} = reluLayer ;
    net.layers{end+1} = dropoutLayer ;

    net.layers{end+1} = convLayer(1, 1, 500, 10, 0, 'fc3') ;

elseif contains(dataset, 'cifar')
    %% CIFAR: three (conv, conv, pool) stages, then two classifier layers
    % NOTE(review): several layers below share the same name ('conv1' is
    % used twice and 'conv2' four times). The names are kept unchanged in
    % case other code refers to them.
    net.layers{end+1} = convLayer(3, 3, 3, 48, 1, 'conv1') ;
    net.layers{end+1} = reluLayer ;
    net.layers{end+1} = convLayer(3, 3, 48, 48, 0, 'conv1') ;
    net.layers{end+1} = reluLayer ;
    net.layers{end+1} = poolLayer ;

    net.layers{end+1} = convLayer(3, 3, 48, 96, 1, 'conv2') ;
    net.layers{end+1} = reluLayer ;
    net.layers{end+1} = convLayer(3, 3, 96, 96, 0, 'conv2') ;
    net.layers{end+1} = reluLayer ;
    net.layers{end+1} = poolLayer ;

    net.layers{end+1} = convLayer(3, 3, 96, 192, 1, 'conv2') ;
    net.layers{end+1} = reluLayer ;
    net.layers{end+1} = convLayer(3, 3, 192, 192, 0, 'conv2') ;
    net.layers{end+1} = reluLayer ;
    net.layers{end+1} = poolLayer ;

    net.layers{end+1} = convLayer(1, 1, 192, 512, 0, 'fc2') ;
    net.layers{end+1} = reluLayer ;
    net.layers{end+1} = dropoutLayer ;

    net.layers{end+1} = convLayer(2, 2, 512, 10, 0, 'fc3') ;

else
    %% Tabular datasets: three fully connected layers
    % The first layer spans the whole feature column (nb_features x 1).
    nb_features = size(imdb.images.data, 1);

    % Widths of the two hidden layers, chosen per dataset.
    switch dataset
        case {'ionosphere', 'WP_Breast_Cancer'}
            nb_hidden = [500; 50];
        case 'SPECTF_Heart'
            nb_hidden = [650; 65];
        case 'pid'
            nb_hidden = [200; 20];
        case {'glass', 'yeast_8l'}
            nb_hidden = [200; 100];
        case 'car'
            nb_hidden = [150; 75];
        case 'satimage'
            nb_hidden = [600; 100];
        case 'thyroid'
            nb_hidden = [350; 75];
        otherwise
            % Previously this only printed 'wrong dataset' and then failed
            % further down on an undefined variable; fail here instead,
            % with a message that names the offending value.
            error('cnn_init_classification:unknownDataset', ...
                  'wrong dataset: ''%s'' is not a supported dataset name.', ...
                  dataset);
    end
    % The output width is the largest label value found in imdb.
    nb_filters = [nb_hidden; max(imdb.images.labels)];

    net.layers{end+1} = convLayer(nb_features, 1, 1, nb_filters(1), 0, 'fc1') ;
    net.layers{end+1} = reluLayer ;

    net.layers{end+1} = convLayer(1, 1, nb_filters(1), nb_filters(2), 0, 'fc2') ;
    net.layers{end+1} = reluLayer ;

    net.layers{end+1} = convLayer(1, 1, nb_filters(2), nb_filters(3), 0, 'fc3') ;
end

%% Loss layer, selected by lambda
if lambda == 0
    net.layers{end+1} = struct('type', 'softmaxloss') ;
elseif lambda == 1
    net.layers{end+1} = struct('type', 'softmaxloss_relevance') ;
elseif lambda == 2
    net.layers{end+1} = struct('type', 'mshinge') ;
elseif lambda == 3
    net.layers{end+1} = struct('type', 'mshinge_relevance') ;
elseif lambda == 4
    net.layers{end+1} = struct('type', 'euclidean') ;
elseif lambda == 5
    net.layers{end+1} = struct('type', 'euclidean_relevance_prod') ;
elseif lambda == 6
    net.layers{end+1} = struct('type', 'euclidean_relevance_prodsup1') ;
elseif lambda == 7
    net.layers{end+1} = struct('type', 'euclidean_relevance_sum') ;
elseif lambda == 8
    net.layers{end+1} = struct('type', 'mshingesquared') ;
elseif lambda == 9
    net.layers{end+1} = struct('type', 'mshingesquared_relevance') ;
elseif lambda == 10
    % A sigmoid is inserted before the Euclidean loss
    % (see http://rohanvarma.me/Loss-Functions/).
    net.layers{end+1} = struct('type', 'sigmoid') ;
    net.layers{end+1} = struct('type', 'euclidean') ;
elseif lambda == 11
    % A sigmoid is inserted before the Euclidean loss
    % (see http://rohanvarma.me/Loss-Functions/).
    net.layers{end+1} = struct('type', 'sigmoid') ;
    net.layers{end+1} = struct('type', 'euclidean_relevance_prodsup1') ;
elseif lambda == 12
    net.layers{end+1} = struct('type', 'mshingecubed') ;
elseif lambda == 13
    net.layers{end+1} = struct('type', 'mshingecubed_relevance') ;
else
    % The network is still returned, as before, but without a loss layer.
    warning('cnn_init_classification:unknownLambda', ...
            'Unrecognised lambda value; no loss layer was appended to the network.');
end

%% Visualize the network
% The input depth is taken from the first layer's filters (1 for the MNIST
% and tabular architectures, 3 for CIFAR) instead of being fixed to 1.
% The trailing 10 is the fourth entry of the displayed input size.
nb_channels = size(net.layers{1}.weights{1}, 3);
vl_simplenn_display(net, 'inputSize', [size(imdb.images.data, 1) size(imdb.images.data, 2) nb_channels 10])